{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "26be3612-0f3d-44bb-b052-fe87ebfe9777",
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore\n",
    "# "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "58ff91ca-ce92-43d0-ae8b-4e9e89e193f6",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 0.596 seconds.\n",
      "Prefix dict has been built successfully.\n",
      "The following parameters in models are missing parameter:\n",
      "['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']\n"
     ]
    }
   ],
   "source": [
    "from mindnlp.dataset import load_dataset\n",
    "from mindnlp.transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
    "from mindnlp.engine import set_seed\n",
    "from mindnlp.peft import get_peft_model, MultitaskPromptTuningConfig, TaskType, MultitaskPromptTuningInit\n",
    "\n",
    "set_seed(42)\n",
    "\n",
    "model_name = \"google/flan-t5-base\"\n",
    "\n",
    "peft_config = MultitaskPromptTuningConfig(\n",
    "    tokenizer_name_or_path=model_name,\n",
    "    num_tasks=2,\n",
    "    task_type=TaskType.SEQ_2_SEQ_LM,\n",
    "    prompt_tuning_init=MultitaskPromptTuningInit.TEXT,\n",
    "    num_virtual_tokens=50,\n",
    "    num_transformer_submodules=1,\n",
    "    prompt_tuning_init_text=\"classify the following into either positive or negative, or entailment, neutral or contradiction:\",\n",
    ")\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
    "model = get_peft_model(model, peft_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "38769636-59cb-48f3-9267-d1922bd16346",
   "metadata": {},
   "outputs": [],
   "source": [
    "sst_dataset = load_dataset(\"sst2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1325c3fb-c437-4d9c-9614-2f47c8abe840",
   "metadata": {},
   "outputs": [],
   "source": [
    "multi_nli_dataset = load_dataset(\"multi_nli\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a8d4ca30-e14d-4a55-b7d6-6f5f05fd0fe8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from mindnlp.dataset import BaseMapFunction\n",
    "\n",
    "class SST2Map(BaseMapFunction):\n",
    "    def __call__(self, idx, sentence, label):\n",
    "        input = str(sentence).strip() + \"</s>\"\n",
    "        output = (f\"positive{tokenizer.eos_token}\" if label == 1 else f\"negative{tokenizer.eos_token}\")\n",
    "        input = tokenizer(input, add_special_tokens=False)\n",
    "        output = tokenizer(output, add_special_tokens=False).input_ids\n",
    "        output = np.where(np.equal(output, tokenizer.pad_token_id), -100, output)\n",
    "        task_ids = 0\n",
    "        return input.input_ids, input.attention_mask, output, task_ids\n",
    "\n",
    "class MNLIMap(BaseMapFunction):\n",
    "    def __call__(self, promptID, pairID, premise, premise_binary_parse, premise_parse, hypothesis,\n",
    "                 hypothesis_binary_parse, hypothesis_parse, genre, label):\n",
    "        input = str(premise).strip() + \" \" + str(hypothesis).strip() + \"</s>\"\n",
    "        if label == 0:\n",
    "            output = f\"entailment{tokenizer.eos_token}\"\n",
    "        elif label == 1:\n",
    "            output = f\"neutral{tokenizer.eos_token}\"\n",
    "        else:\n",
    "            output = f\"contradiction{tokenizer.eos_token}\"\n",
    "\n",
    "        input = tokenizer(input, add_special_tokens=False)\n",
    "        output = tokenizer(output, add_special_tokens=False).input_ids\n",
    "        output = np.where(np.equal(output, tokenizer.pad_token_id), -100, output)\n",
    "        task_ids = 1\n",
    "        return input.input_ids, input.attention_mask, output, task_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "17e8e74c-7171-45af-9808-cbb8056563ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_combine_dataset(shuffle=False, batch_size=8):\n",
    "    sst_input_colums=['idx', 'sentence', 'label']\n",
    "    mnli_input_colums=['promptID', 'pairID', 'premise', 'premise_binary_parse', 'premise_parse', 'hypothesis',\n",
    "                       'hypothesis_binary_parse', 'hypothesis_parse', 'genre', 'label']\n",
    "    output_columns=['input_ids', 'attention_mask', 'labels', 'task_ids']\n",
    "    sst_train = sst_dataset['train']\n",
    "    sst_train = sst_train.map(SST2Map(sst_input_colums, output_columns), sst_input_colums, output_columns)\n",
    "\n",
    "    mnli_train = multi_nli_dataset['train']\n",
    "    mnli_train = mnli_train.map(MNLIMap(mnli_input_colums, output_columns), mnli_input_colums, output_columns)\n",
    "    train_dataset = sst_train + mnli_train\n",
    "    train_dataset = train_dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),\n",
    "                                                                     'attention_mask': (None, 0),\n",
    "                                                                     'labels': (None, tokenizer.pad_token_id)})\n",
    "    if shuffle:\n",
    "        train_dataset = train_dataset.shuffle(1024)\n",
    "\n",
    "    return train_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bb70f7c2-5e70-4dc7-9549-4f80b09a5c94",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sst_dataset(mode, shuffle=False, batch_size=8):\n",
    "    sst_input_colums=['idx', 'sentence', 'label']\n",
    "    output_columns=['input_ids', 'attention_mask', 'labels', 'task_ids']\n",
    "    sst_data = sst_dataset[mode]\n",
    "    sst_data = sst_data.map(SST2Map(sst_input_colums, output_columns), sst_input_colums, output_columns)\n",
    "    if shuffle:\n",
    "        sst_data = sst_data.shuffle(64)\n",
    "    sst_data = sst_data.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),\n",
    "                                                             'attention_mask': (None, 0),\n",
    "                                                             'labels': (None, tokenizer.pad_token_id)})\n",
    "    return sst_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "cc1b754c-9958-4737-a915-ba08d03fa34e",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = get_combine_dataset(shuffle=True, batch_size=8)\n",
    "eval_dataset = get_sst_dataset('validation', batch_size=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "09f63e8d-1d28-4869-9672-0e851b092493",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': Tensor(shape=[8, 26], dtype=Int64, value=\n",
       " [[  151,    12,  4068 ...     0,     0,     0],\n",
       "  [   19,    38,     3 ...    40,   814,     1],\n",
       "  [   34,    19,    46 ...     1,     0,     0],\n",
       "  ...\n",
       "  [12430,   920,     1 ...     0,     0,     0],\n",
       "  [ 4657,    95,    66 ...     0,     0,     0],\n",
       "  [   81,   985,    13 ...     0,     0,     0]]),\n",
       " 'attention_mask': Tensor(shape=[8, 26], dtype=Int64, value=\n",
       " [[1, 1, 1 ... 0, 0, 0],\n",
       "  [1, 1, 1 ... 1, 1, 1],\n",
       "  [1, 1, 1 ... 1, 0, 0],\n",
       "  ...\n",
       "  [1, 1, 1 ... 0, 0, 0],\n",
       "  [1, 1, 1 ... 0, 0, 0],\n",
       "  [1, 1, 1 ... 0, 0, 0]]),\n",
       " 'labels': Tensor(shape=[8, 2], dtype=Int64, value=\n",
       " [[2841,    1],\n",
       "  [2841,    1],\n",
       "  [1465,    1],\n",
       "  ...\n",
       "  [2841,    1],\n",
       "  [1465,    1],\n",
       "  [1465,    1]]),\n",
       " 'task_ids': Tensor(shape=[8], dtype=Int64, value= [0, 0, 0, 0, 0, 0, 0, 0])}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(train_dataset.create_dict_iterator())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe0aec7b-f61e-4b00-a90e-c1201dc1f84c",
   "metadata": {},
   "source": [
    "## source training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "cceecc94-f43a-4f62-8d45-926f2f02f36d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from mindspore.experimental.optim.adamw import AdamW\n",
    "from mindnlp.modules.optimization import get_cosine_schedule_with_warmup\n",
    "from tqdm import tqdm\n",
    "from sklearn.metrics import f1_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "eae5516b-73ab-44a8-a083-4e8de6127f30",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "POSITIVE_TOKEN_ID = tokenizer(\" positive\", add_special_tokens=False)[\"input_ids\"][0]\n",
    "NEGATIVE_TOKEN_ID = tokenizer(\" negative\", add_special_tokens=False)[\"input_ids\"][0]\n",
    "\n",
    "\n",
    "def classify(batch):\n",
    "    # we pass labels here since we need to generate and peft doesn't support generation yet.\n",
    "    # No clue how to get around this\n",
    "    scores = model(**batch).logits\n",
    "    preds = []\n",
    "    for i in range(scores.shape[0]):\n",
    "        if scores[i, 0, POSITIVE_TOKEN_ID] > scores[i, 0, NEGATIVE_TOKEN_ID]:\n",
    "            preds.append(POSITIVE_TOKEN_ID)\n",
    "        else:\n",
    "            preds.append(NEGATIVE_TOKEN_ID)\n",
    "    return preds\n",
    "\n",
    "\n",
    "def evaluate(model, data):\n",
    "    model.set_train(False)\n",
    "    loss = 0\n",
    "    preds = []\n",
    "    golds = []\n",
    "\n",
    "    total = data.get_dataset_size()\n",
    "    for batch in tqdm(data.create_dict_iterator(), total=total):\n",
    "        with mindspore._no_grad():\n",
    "            loss += model(**batch).loss\n",
    "        golds.extend(batch[\"labels\"][:, 0].tolist())\n",
    "        preds.extend(classify(batch))\n",
    "\n",
    "    return loss / total, f1_score(golds, preds, pos_label=POSITIVE_TOKEN_ID)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8225f6f5-4e3a-45dc-9b1c-d1c80ff76b83",
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = AdamW(model.trainable_params(), lr=1e-4)\n",
    "scheduler = get_cosine_schedule_with_warmup(optimizer, 200, len(train_dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14c45e28-d4e1-41a1-9937-46585617d49b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 109/109 [00:29<00:00,  3.72it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "before source training, val loss = 14.657343, f1 = 0.31596091205211724\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                          | 35/57507 [00:26<9:02:49,  1.76it/s, train_loss=9.211916]"
     ]
    }
   ],
   "source": [
    "n = 1000\n",
    "step = 0\n",
    "\n",
    "val_loss, f1 = evaluate(model, eval_dataset)\n",
    "print(f\"\"\"before source training, val loss = {val_loss}, f1 = {f1}\"\"\")\n",
    "\n",
    "# training and evaluation\n",
    "def forward_fn(**batch):\n",
    "    outputs = model(**batch)\n",
    "    loss = outputs.loss\n",
    "    return loss\n",
    "\n",
    "grad_fn = mindspore.value_and_grad(forward_fn, None, model.trainable_params())\n",
    "\n",
    "def train_step(**batch):\n",
    "    loss, grads = grad_fn(**batch)\n",
    "    optimizer(grads)\n",
    "    return loss\n",
    "\n",
    "\n",
    "train_total = train_dataset.get_dataset_size()\n",
    "train_ = tqdm(train_dataset.create_dict_iterator(), total=train_total)\n",
    "\n",
    "for batch in train_:\n",
    "    if step % n == 0 and step != 0:\n",
    "        val_loss, f1 = evaluate(model, eval_dataset)\n",
    "        print(f\"\"\"step = {step}, val loss = {val_loss}, f1 = {f1}\"\"\")\n",
    "        model.save_pretrained(f\"checkpoints_source/{step}\")\n",
    "\n",
    "    model.set_train()\n",
    "    step += 1\n",
    "    loss = train_step(**batch)\n",
    "    scheduler.step()\n",
    "    train_.set_postfix(train_loss=loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74168ef3-66f3-41a7-a40b-7840b103fbf9",
   "metadata": {},
   "source": [
    "## target training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b09fd456-163e-4dc1-b24d-f2d0d349036c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "train_dataset = get_sst_dataset('train', shuffle=True, batch_size=8)\n",
    "eval_dataset = get_sst_dataset('validation', batch_size=8)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a539944-f16c-4c3f-bb4a-7b5d9a6042e2",
   "metadata": {},
   "source": [
    "#### create a fresh model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5520d904-aa6c-4654-9335-ed4e7d76cba2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "peft_config = MultitaskPromptTuningConfig(\n",
    "    tokenizer_name_or_path=model_name,\n",
    "    num_tasks=1,\n",
    "    task_type=TaskType.SEQ_2_SEQ_LM,\n",
    "    prompt_tuning_init=MultitaskPromptTuningInit.EXACT_SOURCE_TASK,\n",
    "    prompt_tuning_init_state_dict_path=\"checkpoints_source/50000/adapter_model.bin\",\n",
    "    num_virtual_tokens=50,\n",
    "    num_transformer_submodules=1,\n",
    ")\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
    "model = get_peft_model(model, peft_config)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfa39c2d-d1c5-4ed4-90f8-26e8e324371c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "optimizer = AdamW(model.trainable_params(), lr=1e-4)\n",
    "scheduler = get_cosine_schedule_with_warmup(optimizer, 200, len(train_dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "134d053c-83aa-44fc-acf1-880987f9a30b",
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 1000\n",
    "step = 0\n",
    "\n",
    "val_loss, f1 = evaluate(model, eval_dataset)\n",
    "print(f\"\"\"before source training, val loss = {val_loss}, f1 = {f1}\"\"\")\n",
    "\n",
    "# training and evaluation\n",
    "def forward_fn(**batch):\n",
    "    outputs = model(**batch)\n",
    "    loss = outputs.loss\n",
    "    return loss\n",
    "\n",
    "grad_fn = mindspore.value_and_grad(forward_fn, None, model.trainable_params())\n",
    "\n",
    "def train_step(**batch):\n",
    "    loss, grads = grad_fn(**batch)\n",
    "    optimizer(grads)\n",
    "    return loss\n",
    "\n",
    "\n",
    "train_total = train_dataset.get_dataset_size()\n",
    "train_ = tqdm(train_dataset.create_dict_iterator(), total=train_total)\n",
    "\n",
    "for batch in train_:\n",
    "    if step % n == 0 and step != 0:\n",
    "        val_loss, f1 = evaluate(model, eval_dataset)\n",
    "        print(f\"\"\"step = {step}, val loss = {val_loss}, f1 = {f1}\"\"\")\n",
    "        model.save_pretrained(f\"checkpoints_source/{step}\")\n",
    "\n",
    "    model.set_train()\n",
    "    step += 1\n",
    "    loss = train_step(**batch)\n",
    "    scheduler.step()\n",
    "    train_.set_postfix(train_loss=loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6a6eeda-1e09-49a6-8845-cd96c8573145",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# load last checkpoint for now\n",
    "from mindnlp.peft import set_peft_model_state_dict\n",
    "\n",
    "sd_6000 = mindspore.load_checkpoint(\"checkpoints_target/6000/adapter_model.ckpt\")\n",
    "set_peft_model_state_dict(model, sd_6000)\n",
    "\n",
    "# evaluate val\n",
    "val_loss, f1 = evaluate(model, eval_dataset)\n",
    "print(\n",
    "    f\"\"\"\n",
    "final\n",
    "val loss = {val_loss}\n",
    "f1 = {f1}\"\"\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mindspore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
